# 该代码主要来源于慕课网，地址为 https://www.imooc.com/video/14996
# 本节主要讲的是：用决策树对鸢尾花（iris）数据集进行分类

from pathlib import Path

import numpy as np
import pandas as pd
from sklearn import metrics
from sklearn import tree
from sklearn.datasets import load_iris

try:
    from sklearn.model_selection import train_test_split
except ImportError:  # scikit-learn < 0.18 only ships the old module
    from sklearn.cross_validation import train_test_split

# Load scikit-learn's bundled iris dataset (features in iris.data, labels in iris.target).
iris = load_iris()

# Hold out 20% of the samples for testing; the fixed seed makes the split reproducible.
train_data, test_data, train_target, test_target = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=1
)

# Modeling: a decision tree that chooses splits by information gain (entropy).
# random_state pins the tree's internal feature permutation, so ties between
# equally good splits are broken the same way on every run. The same seed is
# used for the data split above.
clf = tree.DecisionTreeClassifier(criterion="entropy", random_state=1)
clf.fit(train_data, train_target)
y_predict = clf.predict(test_data)
print(y_predict)

# Evaluation: overall accuracy, then the confusion matrix
# (rows = true class, columns = predicted class).
print(metrics.accuracy_score(y_true=test_target, y_pred=y_predict))
print(metrics.confusion_matrix(y_true=test_target, y_pred=y_predict))

# Export the fitted tree in Graphviz DOT format.
# Create the output directory first: open() does not create it, and would
# raise FileNotFoundError when ./data is missing.
dot_path = Path("./data/sklearn_02_tree.dot")
dot_path.parent.mkdir(parents=True, exist_ok=True)
with dot_path.open("w", encoding="utf-8") as fw:
    tree.export_graphviz(clf, out_file=fw)
